package com.luis.toolsuite.controller;

import com.luis.toolsuite.isolationforest.IForest;
import com.luis.toolsuite.isolationforest.model.IfDataPoint;
import com.luis.toolsuite.service.IfDataRepository;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Spring MVC controller exposing an Isolation Forest anomaly-scoring endpoint.
 * Training and test points are loaded from {@link IfDataRepository}.
 */
@Controller
public class IsForestController {

	/** Number of trees built in the isolation forest. */
	private static final int SUB_TREE_NUM = 100;

	/** Training points whose value2 is not strictly inside (-limit, limit) are discarded as outliers. */
	private static final double TRAIN_VALUE2_LIMIT = 20;

	/** Columns in each output row: time, anomaly score, then the 7 supplementary fields. */
	private static final int RESULT_COLUMNS = 9;

	/** Index in a test sample where the supplementary fields (rise ... huanshoulv) start. */
	private static final int TEST_EXTRA_OFFSET = 3;

	/** Index in an output row where the supplementary fields are copied to. */
	private static final int RESULT_EXTRA_OFFSET = 2;

	private static final long NANOS_PER_MILLI = 1_000_000L;

	@Autowired
	private IfDataRepository dataService;

	/**
	 * Trains an isolation forest on the repository's training data, scores every test point,
	 * and returns the scored rows together with sizes and timings.
	 *
	 * @return an insertion-ordered map with keys:
	 *         {@code data} - list of rows {@code [time, score, rise, lastclose, open, close, high, low, huanshoulv]};
	 *         {@code train_size} - number of training samples after outlier filtering;
	 *         {@code train_cost} - training time in milliseconds;
	 *         {@code calc_size} - number of test samples scored;
	 *         {@code calc_cost} - scoring time in milliseconds.
	 */
	@RequestMapping("/isforest/data")
	@ResponseBody
	public Map<String,Object> data() {
		Map<String,Object> resultMap = new LinkedHashMap<>();
		IForest forest = new IForest();

		// Training data; outliers are removed by transformFromDataListToTrain.
		List<IfDataPoint> trainList = dataService.findForTrain();
		double[][] trainSamples = this.transformFromDataListToTrain(trainList);
		// nanoTime is monotonic (unlike currentTimeMillis), so it is the right clock for elapsed time.
		long trainStart = System.nanoTime();
		forest.train(trainSamples, SUB_TREE_NUM);
		long trainCostMillis = (System.nanoTime() - trainStart) / NANOS_PER_MILLI;

		// Validation data; column 0 is the timestamp, carried through only for display.
		List<IfDataPoint> testList = dataService.findForTest();
		double[][] testSamples = this.transformFromDataListToTest(testList);
		List<double[]> resultList = new ArrayList<>(testSamples.length);
		long calcStart = System.nanoTime();
		for (double[] sample : testSamples) {
			resultList.add(toResultRow(forest, sample));
		}
		long calcCostMillis = (System.nanoTime() - calcStart) / NANOS_PER_MILLI;

		resultMap.put("data", resultList);
		resultMap.put("train_size", trainSamples.length);
		resultMap.put("train_cost", trainCostMillis);
		resultMap.put("calc_size", testSamples.length);
		resultMap.put("calc_cost", calcCostMillis);
		return resultMap;
	}

	/**
	 * Builds one output row from a test sample produced by {@link #transformFromDataListToTest}.
	 * The anomaly score is computed from (value1, value2) only; the supplementary fields
	 * (sample columns 3..9) are copied unchanged into row columns 2..8.
	 *
	 * @param forest trained forest used to score the sample
	 * @param sample test row {@code [time, value1, value2, rise, ..., huanshoulv]}
	 * @return row {@code [time, score, rise, lastclose, open, close, high, low, huanshoulv]}
	 */
	private static double[] toResultRow(IForest forest, double[] sample) {
		double[] row = new double[RESULT_COLUMNS];
		row[0] = sample[0];
		row[1] = forest.computeAnomalyScore(new double[]{sample[1], sample[2]});
		System.arraycopy(sample, TEST_EXTRA_OFFSET, row, RESULT_EXTRA_OFFSET, RESULT_COLUMNS - RESULT_EXTRA_OFFSET);
		return row;
	}

	/**
	 * Converts training points into {@code [value1, value2]} rows, discarding outliers whose
	 * value2 is not strictly between {@code -TRAIN_VALUE2_LIMIT} and {@code TRAIN_VALUE2_LIMIT}.
	 *
	 * @param dataPointList raw training points
	 * @return one row per retained point, in input order
	 * @throws NumberFormatException if value1/value2 is not a parseable double
	 */
	private double[][] transformFromDataListToTrain(List<IfDataPoint> dataPointList){
		List<IfDataPoint> validList = dataPointList.stream()
				.filter(t -> {
					double value2 = Double.parseDouble(t.getValue2());
					return value2 > -TRAIN_VALUE2_LIMIT && value2 < TRAIN_VALUE2_LIMIT;
				})
				.collect(Collectors.toList());
		double[][] result = new double[validList.size()][];
		int row = 0;
		// Read from the filtered list: rows must correspond to the retained points only.
		for (IfDataPoint point : validList) {
			result[row++] = new double[]{
					Double.parseDouble(point.getValue1()),
					Double.parseDouble(point.getValue2())
			};
		}
		return result;
	}

	/**
	 * Converts test points into numeric rows:
	 * {@code [time, value1, value2, rise, lastclose, open, close, high, low, huanshoulv]}.
	 * Column 0 is taken from {@code getTime().getTime()}; columns 3..9 are supplementary
	 * information that is passed through to the output, not used for scoring.
	 *
	 * @param dataPointList raw test points
	 * @return one 10-column row per point, in input order
	 * @throws NumberFormatException if any string field is not a parseable double
	 */
	private double[][] transformFromDataListToTest(List<IfDataPoint> dataPointList){
		double[][] result = new double[dataPointList.size()][];
		int row = 0;
		for (IfDataPoint point : dataPointList) {
			result[row++] = new double[]{
					point.getTime().getTime(),
					Double.parseDouble(point.getValue1()),
					Double.parseDouble(point.getValue2()),
					Double.parseDouble(point.getRise()),
					Double.parseDouble(point.getLastclose()),
					Double.parseDouble(point.getOpen()),
					Double.parseDouble(point.getClose()),
					Double.parseDouble(point.getHigh()),
					Double.parseDouble(point.getLow()),
					Double.parseDouble(point.getHuanshoulv())
			};
		}
		return result;
	}

}
